import torch

def imshift(img_in, s1, s2, device_1=None):
    """Circularly shift an image by (possibly fractional) offsets via the FFT.

    Uses the Fourier shift theorem: the 2-D spectrum of the image is
    multiplied by a linear phase ramp and transformed back, which moves the
    content by ``s1`` samples along the row axis (axis -2) and ``s2`` samples
    along the column axis (axis -1). Positive offsets move content toward
    higher indices; content that leaves one border wraps around to the
    opposite one.

    Args:
        img_in: Real or complex tensor. All singleton dimensions are squeezed
            away first; the remaining tensor must have at least two
            dimensions. The shift is applied over its last two dimensions,
            and any leading dimensions (e.g. channels) are shifted
            independently.
        s1: Vertical shift (along rows, axis -2) in samples; may be fractional
            or a scalar tensor.
        s2: Horizontal shift (along columns, axis -1) in samples; may be
            fractional or a scalar tensor.
        device_1: Device on which the phase ramp is built. Defaults to the
            device of ``img_in``; it must match the image's device.

    Returns:
        Complex tensor with the squeezed shape of ``img_in``. For a real
        input and integer shifts the imaginary part is only round-off, so
        callers wanting a real image should take ``.real``.

    Raises:
        ValueError: If ``img_in`` has fewer than two dimensions after
            squeezing.
    """
    img = img_in.squeeze()
    if img.dim() < 2:
        raise ValueError(
            "imshift expects an image with at least two non-singleton "
            f"dimensions, got shape {tuple(img_in.shape)}"
        )
    device = img.device if device_1 is None else device_1
    n_rows, n_cols = img.shape[-2:]

    spectrum = torch.fft.fft2(img)

    # Build the ramp in the spectrum's own real precision so float64 images
    # are not degraded by a float32 phase.
    real_dtype = spectrum.real.dtype

    # fftfreq yields integer frequency / n in the *unshifted* FFT order
    # (Nyquist at -n/2 for even n). No fftshift/ifftshift round trip is
    # needed, and odd sizes get integer frequencies rather than half-integers.
    k_rows = torch.fft.fftfreq(n_rows, device=device, dtype=real_dtype)
    k_cols = torch.fft.fftfreq(n_cols, device=device, dtype=real_dtype)

    # Outer sum -> (n_rows, n_cols) grid that broadcasts over leading dims.
    phase = s1 * k_rows[:, None] + s2 * k_cols[None, :]
    ramp = torch.exp(-2j * torch.pi * phase)

    return torch.fft.ifft2(spectrum * ramp)